{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "from pytorch_tabnet.tab_model import TabNetClassifier\n",
    "\n",
    "import torch\n",
    "from sklearn.preprocessing import LabelEncoder\n",
    "from sklearn.metrics import accuracy_score\n",
    "from sklearn.model_selection import train_test_split\n",
    "import pandas as pd\n",
    "import numpy as np\n",
    "np.random.seed(0)\n",
    "\n",
    "\n",
    "import os\n",
    "import wget\n",
    "from pathlib import Path\n",
    "import shutil\n",
    "import gzip\n",
    "\n",
    "from matplotlib import pyplot as plt\n",
    "%matplotlib inline"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Download ForestCoverType dataset"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "url = \"https://archive.ics.uci.edu/ml/machine-learning-databases/covtype/covtype.data.gz\"\n",
    "dataset_name = 'forest-cover-type'\n",
    "tmp_out = Path('./data/'+dataset_name+'.gz')\n",
    "out = Path(os.getcwd()+'/data/'+dataset_name+'.csv')"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "out.parent.mkdir(parents=True, exist_ok=True)\n",
    "if out.exists():\n",
    "    print(\"File already exists.\")\n",
    "else:\n",
    "    print(\"Downloading file...\")\n",
    "    wget.download(url, tmp_out.as_posix())\n",
    "    with gzip.open(tmp_out, 'rb') as f_in:\n",
    "        with open(out, 'wb') as f_out:\n",
    "            shutil.copyfileobj(f_in, f_out)\n",
    "    \n"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Load data and split\n",
    "Same split as in original paper"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "target = \"Covertype\"\n",
    "\n",
    "bool_columns = [\n",
    "    \"Wilderness_Area1\", \"Wilderness_Area2\", \"Wilderness_Area3\",\n",
    "    \"Wilderness_Area4\", \"Soil_Type1\", \"Soil_Type2\", \"Soil_Type3\", \"Soil_Type4\",\n",
    "    \"Soil_Type5\", \"Soil_Type6\", \"Soil_Type7\", \"Soil_Type8\", \"Soil_Type9\",\n",
    "    \"Soil_Type10\", \"Soil_Type11\", \"Soil_Type12\", \"Soil_Type13\", \"Soil_Type14\",\n",
    "    \"Soil_Type15\", \"Soil_Type16\", \"Soil_Type17\", \"Soil_Type18\", \"Soil_Type19\",\n",
    "    \"Soil_Type20\", \"Soil_Type21\", \"Soil_Type22\", \"Soil_Type23\", \"Soil_Type24\",\n",
    "    \"Soil_Type25\", \"Soil_Type26\", \"Soil_Type27\", \"Soil_Type28\", \"Soil_Type29\",\n",
    "    \"Soil_Type30\", \"Soil_Type31\", \"Soil_Type32\", \"Soil_Type33\", \"Soil_Type34\",\n",
    "    \"Soil_Type35\", \"Soil_Type36\", \"Soil_Type37\", \"Soil_Type38\", \"Soil_Type39\",\n",
    "    \"Soil_Type40\"\n",
    "]\n",
    "\n",
    "int_columns = [\n",
    "    \"Elevation\", \"Aspect\", \"Slope\", \"Horizontal_Distance_To_Hydrology\",\n",
    "    \"Vertical_Distance_To_Hydrology\", \"Horizontal_Distance_To_Roadways\",\n",
    "    \"Hillshade_9am\", \"Hillshade_Noon\", \"Hillshade_3pm\",\n",
    "    \"Horizontal_Distance_To_Fire_Points\"\n",
    "]\n",
    "\n",
    "feature_columns = (\n",
    "    int_columns + bool_columns + [target])\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "train = pd.read_csv(out, header=None, names=feature_columns)\n",
    "\n",
    "n_total = len(train)\n",
    "\n",
    "# Train, val and test split follows\n",
    "# Rory Mitchell, Andrey Adinets, Thejaswi Rao, and Eibe Frank.\n",
    "# Xgboost: Scalable GPU accelerated learning. arXiv:1806.11248, 2018.\n",
    "\n",
    "train_val_indices, test_indices = train_test_split(\n",
    "    range(n_total), test_size=0.2, random_state=0)\n",
    "train_indices, valid_indices = train_test_split(\n",
    "    train_val_indices, test_size=0.2 / 0.6, random_state=0)\n"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Simple preprocessing\n",
    "\n",
    "Label encode categorical features and fill empty cells."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "categorical_columns = []\n",
    "categorical_dims =  {}\n",
    "for col in train.columns[train.dtypes == object]:\n",
    "    print(col, train[col].nunique())\n",
    "    l_enc = LabelEncoder()\n",
    "    train[col] = train[col].fillna(\"VV_likely\")\n",
    "    train[col] = l_enc.fit_transform(train[col].values)\n",
    "    categorical_columns.append(col)\n",
    "    categorical_dims[col] = len(l_enc.classes_)\n",
    "\n",
    "for col in train.columns[train.dtypes == 'float64']:\n",
    "    train.fillna(train.loc[train_indices, col].mean(), inplace=True)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Define categorical features for categorical embeddings"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# This is a generic pipeline but actually no categorical features are available for this dataset\n",
    "\n",
    "unused_feat = []\n",
    "\n",
    "features = [ col for col in train.columns if col not in unused_feat+[target]] \n",
    "\n",
    "cat_idxs = [ i for i, f in enumerate(features) if f in categorical_columns]\n",
    "\n",
    "cat_dims = [ categorical_dims[f] for i, f in enumerate(features) if f in categorical_columns]\n"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Network parameters"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "clf = TabNetClassifier(\n",
    "    n_d=64, n_a=64, n_steps=5,\n",
    "    gamma=1.5, n_independent=2, n_shared=2,\n",
    "    cat_idxs=cat_idxs,\n",
    "    cat_dims=cat_dims,\n",
    "    cat_emb_dim=1,\n",
    "    lambda_sparse=1e-4, momentum=0.3, clip_value=2.,\n",
    "    optimizer_fn=torch.optim.Adam,\n",
    "    optimizer_params=dict(lr=2e-2),\n",
    "    scheduler_params = {\"gamma\": 0.95,\n",
    "                     \"step_size\": 20},\n",
    "    scheduler_fn=torch.optim.lr_scheduler.StepLR, epsilon=1e-15\n",
    ")"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Training"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "if os.getenv(\"CI\", False):\n",
    "# Take only a subsample to run CI\n",
    "    X_train = train[features].values[train_indices][:1000,:]\n",
    "    y_train = train[target].values[train_indices][:1000]\n",
    "else:\n",
    "    X_train = train[features].values[train_indices]\n",
    "    y_train = train[target].values[train_indices]\n",
    "\n",
    "X_valid = train[features].values[valid_indices]\n",
    "y_valid = train[target].values[valid_indices]\n",
    "\n",
    "X_test = train[features].values[test_indices]\n",
    "y_test = train[target].values[test_indices]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "max_epochs = 100 if not os.getenv(\"CI\", False) else 2"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "scrolled": true
   },
   "outputs": [],
   "source": [
    "from pytorch_tabnet.augmentations import ClassificationSMOTE\n",
    "aug = ClassificationSMOTE(p=0.2)\n",
    "\n",
    "clf.fit(\n",
    "    X_train=X_train, y_train=y_train,\n",
    "    eval_set=[(X_train, y_train), (X_valid, y_valid)],\n",
    "    eval_name=['train', 'valid'],\n",
    "    max_epochs=max_epochs, patience=100,\n",
    "    batch_size=16384, virtual_batch_size=256,\n",
    "    augmentations=aug\n",
    ") "
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# plot losses\n",
    "plt.plot(clf.history['loss'])"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# plot accuracy\n",
    "plt.plot(clf.history['train_accuracy'])\n",
    "plt.plot(clf.history['valid_accuracy'])"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Predictions\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# To get final results you may need to use a mapping for classes \n",
    "# as you are allowed to use targets like [\"yes\", \"no\", \"maybe\", \"I don't know\"]\n",
    "\n",
    "preds_mapper = { idx : class_name for idx, class_name in enumerate(clf.classes_)}\n",
    "\n",
    "preds = clf.predict_proba(X_test)\n",
    "\n",
    "y_pred = np.vectorize(preds_mapper.get)(np.argmax(preds, axis=1))\n",
    "\n",
    "test_acc = accuracy_score(y_pred=y_pred, y_true=y_test)\n",
    "\n",
    "print(f\"BEST VALID SCORE FOR {dataset_name} : {clf.best_cost}\")\n",
    "print(f\"FINAL TEST SCORE FOR {dataset_name} : {test_acc}\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# or you can simply use the predict method\n",
    "\n",
    "y_pred = clf.predict(X_test)\n",
    "test_acc = accuracy_score(y_pred=y_pred, y_true=y_test)\n",
    "print(f\"FINAL TEST SCORE FOR {dataset_name} : {test_acc}\")"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Save and load Model"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# save state dict\n",
    "saved_filename = clf.save_model('test_model')"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# define new model and load save parameters\n",
    "loaded_clf = TabNetClassifier()\n",
    "loaded_clf.load_model(saved_filename)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "loaded_preds = loaded_clf.predict_proba(X_test)\n",
    "loaded_y_pred = np.vectorize(preds_mapper.get)(np.argmax(loaded_preds, axis=1))\n",
    "\n",
    "loaded_test_acc = accuracy_score(y_pred=loaded_y_pred, y_true=y_test)\n",
    "\n",
    "print(f\"FINAL TEST SCORE FOR {dataset_name} : {loaded_test_acc}\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "assert(test_acc == loaded_test_acc)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Global explainability : feat importance summing to 1"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "clf.feature_importances_"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Local explainability and masks"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "explain_matrix, masks = clf.explain(X_test)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "fig, axs = plt.subplots(1, 5, figsize=(20,20))\n",
    "\n",
    "for i in range(5):\n",
    "    axs[i].imshow(masks[i][:50])\n",
    "    axs[i].set_title(f\"mask {i}\")"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# XGB"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "n_estimators = 1000 if not os.getenv(\"CI\", False) else 20"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "scrolled": true
   },
   "outputs": [],
   "source": [
    "from xgboost import XGBClassifier\n",
    "\n",
    "clf_xgb = XGBClassifier(max_depth=8,\n",
    "    learning_rate=0.1,\n",
    "    n_estimators=n_estimators,\n",
    "    verbosity=0,\n",
    "    silent=None,\n",
    "    objective=\"multi:softmax\",\n",
    "    booster='gbtree',\n",
    "    n_jobs=-1,\n",
    "    nthread=None,\n",
    "    gamma=0,\n",
    "    min_child_weight=1,\n",
    "    max_delta_step=0,\n",
    "    subsample=0.7,\n",
    "    colsample_bytree=1,\n",
    "    colsample_bylevel=1,\n",
    "    colsample_bynode=1,\n",
    "    reg_alpha=0,\n",
    "    reg_lambda=1,\n",
    "    scale_pos_weight=1,\n",
    "    base_score=0.5,\n",
    "    random_state=0,\n",
    "    seed=None,)\n",
    "\n",
    "clf_xgb.fit(X_train, y_train,\n",
    "            eval_set=[(X_valid, y_valid)],\n",
    "            early_stopping_rounds=40,\n",
    "            verbose=10)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "preds_valid = np.array(clf_xgb.predict_proba(X_valid, ))\n",
    "valid_acc = accuracy_score(y_pred=np.argmax(preds_valid, axis=1) + 1, y_true=y_valid)\n",
    "print(valid_acc)\n",
    "\n",
    "preds_test = np.array(clf_xgb.predict_proba(X_test))\n",
    "test_acc = accuracy_score(y_pred=np.argmax(preds_test, axis=1) + 1, y_true=y_test)\n",
    "print(test_acc)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.7.13"
  },
  "toc": {
   "base_numbering": 1,
   "nav_menu": {},
   "number_sections": true,
   "sideBar": true,
   "skip_h1_title": false,
   "title_cell": "Table of Contents",
   "title_sidebar": "Contents",
   "toc_cell": false,
   "toc_position": {},
   "toc_section_display": true,
   "toc_window_display": false
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}
